CK Tile Group GEMM gfx1250#576
Conversation
| // Currently only support cutlass group gemm on Hopper Arch | ||
| if (!(is_hopper && use_cutlass)) { | ||
| //if (!(is_hopper && use_cutlass)) { | ||
| if (!use_cutlass) { |
| using type = TileCfg_256x256x64_WMMA; | ||
| }; | ||
|
|
||
| template <GPUArch Arch> |
There was a problem hiding this comment.
Why does it need template over reguler if-else or switch-case?
There was a problem hiding this comment.
The template is needed because the arch selection affects CK kernel template instantiation, not just runtime control flow. GPUArch must be a compile-time value so if constexpr can prune unsupported tile/kernel combinations for a given architecture. In this case, it prevents the MFMA configs from being instantiated for gfx1250.
There was a problem hiding this comment.
I didn't compile it with gfx1250 arch only but I was still puzzled about this templated dispatch. In line 298, you still rely on runtime detect_gpu_arch() to branch to specific ck_tile_grouped_gemm_fp16_dispatch_arch<arch_id>'s. So I presume all three arches verions will still be instantiated? And I didn't see any compile time guarding?
There was a problem hiding this comment.
Good point. Before the latest change the runtime switch still referenced all ck_tile_grouped_gemm_fp16_dispatch_arch<...> specializations, so the compiler could still instantiate all arch variants. b55fe29 adds compile-time #if defined(__gfxXXX__) guards around each runtime dispatch case so unsupported arch paths are no longer instantiated.
| static constexpr ck_tile::index_t M_Warp_Tile = 16; | ||
| static constexpr ck_tile::index_t N_Warp_Tile = 16; | ||
| static constexpr ck_tile::index_t K_Warp_Tile = 32; | ||
|
|
||
| static constexpr bool kPadM = true; | ||
| static constexpr bool kPadN = true; | ||
| static constexpr bool kPadK = true; |
There was a problem hiding this comment.
so the difference btw TileCfg_256x256x64_MFMA and TileCfg_256x256x64_WMMA is inside M, N, K warp tile and kPads?
There was a problem hiding this comment.
The difference is not just the warp tile shape or kPads. MFMA and WMMA are different warp-level MMA instruction paths, so they lower through different warp dispatch/pipeline configurations with different tile and padding requirements.
There was a problem hiding this comment.
Oh, I was comparing those two struct classes TileCfg_256x256x64_MFMA and TileCfg_256x256x64_WMMA. Inside those two defined structs, contents just differ by warp tile and kPads?
| using type = TileCfg_256x256x64_WMMA; | ||
| }; | ||
|
|
||
| template <GPUArch Arch> |
There was a problem hiding this comment.
I didn't compile it with gfx1250 arch only but I was still puzzled about this templated dispatch. In line 298, you still rely on runtime detect_gpu_arch() to branch to specific ck_tile_grouped_gemm_fp16_dispatch_arch<arch_id>'s. So I presume all three arches verions will still be instantiated? And I didn't see any compile time guarding?
| COMPILE_OPTIONS "-g0;-dopt=on") | ||
| else() | ||
| set(CK_ROOT ${CMAKE_CURRENT_SOURCE_DIR}/../../3rdparty/aiter/3rdparty/composable_kernel) | ||
| set(CK_ROOT ${CMAKE_CURRENT_SOURCE_DIR}/../../3rdparty/rocm_libraries/projects/composablekernel) |
There was a problem hiding this comment.
nit: Will the whole rocm_libraries too big? Do we have a way to have sparse check out for this ck subdir?
There was a problem hiding this comment.
Good point. The full rocm_libraries checkout is fairly large (~8G locally), while projects/composablekernel alone is much smaller (~167M). Yeah, sparse checkout probably makes sense here, but I am wondering if it would be better handled in a separate PR.
There was a problem hiding this comment.
Yeah, we can do it in separate PR
…emm-gfx1250-clean
Description
Extend the present CK tile grouped GEMM (F16/F8) implementation for compatibility with gfx1250. Replaces 3rdparty/aiter with 3rdparty/rocm-libraries for the gfx1250 changes from CK.
Fixes #16490
Type of change
Changes
Please list the changes introduced in this PR:
Checklist: